/*
 *   This program is free software: you can redistribute it and/or modify
 *   it under the terms of the GNU General Public License as published by
 *   the Free Software Foundation, either version 3 of the License, or
 *   (at your option) any later version.
 *
 *   This program is distributed in the hope that it will be useful,
 *   but WITHOUT ANY WARRANTY; without even the implied warranty of
 *   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *   GNU General Public License for more details.
 *
 *   You should have received a copy of the GNU General Public License
 *   along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

/*
 *    NaiveBayesSimple.java
 *    Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.bayes;

import java.util.Enumeration;

import weka.classifiers.AbstractClassifier;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.Capabilities.Capability;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;

/**
 * <!-- globalinfo-start --> Class for building and using a simple Naive Bayes
 * classifier. Numeric attributes are modelled by a normal distribution.<br/>
 * <br/>
 * For more information, see<br/>
 * <br/>
 * Richard Duda, Peter Hart (1973). Pattern Classification and Scene Analysis.
 * Wiley, New York.
 * <p/>
 * <!-- globalinfo-end -->
 * 
 * <!-- technical-bibtex-start --> BibTeX:
 * 
 * <pre>
 * &#64;book{Duda1973,
 *    address = {New York},
 *    author = {Richard Duda and Peter Hart},
 *    publisher = {Wiley},
 *    title = {Pattern Classification and Scene Analysis},
 *    year = {1973}
 * }
 * </pre>
 * <p/>
 * <!-- technical-bibtex-end -->
 * 
 * <!-- options-start --> Valid options are:
 * <p/>
 * 
 * <pre>
 * -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console
 * </pre>
 * 
 * <!-- options-end -->
 * 
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @version $Revision$
 */
public class NaiveBayesSimple extends AbstractClassifier implements
  TechnicalInformationHandler {

  /** for serialization */
  static final long serialVersionUID = -1478242251770381214L;

  /**
   * Per-class, per-attribute counts, indexed as [class][attribute][value]. The
   * attribute index is the position of the attribute in the order produced by
   * <code>enumerateAttributes()</code>. For nominal attributes the innermost
   * array has one slot per attribute value; once buildClassifier has finished
   * these slots hold the estimates (count + 1) / (sum + numValues) rather than
   * raw counts. For all other attributes the innermost array has a single slot
   * holding the number of non-missing training values.
   */
  protected double[][][] m_Counts;

  /**
   * The means for numeric attributes, indexed as [class][attribute]. During
   * training the entries temporarily hold the running sum of the values.
   */
  protected double[][] m_Means;

  /**
   * The standard deviations for numeric attributes, indexed as
   * [class][attribute]. Computed as the square root of the sum of squared
   * deviations from the mean divided by (count - 1).
   */
  protected double[][] m_Devs;

  /**
   * The prior probabilities of the classes, estimated as
   * (classCount + 1) / (total + numClasses).
   */
  protected double[] m_Priors;

  /**
   * Dataset created from the training data via
   * <code>new Instances(instances, 0)</code> (presumably the header without
   * any rows -- confirm against the Instances constructor). toString() reads
   * class and attribute names from it and treats <code>null</code> as "no
   * model built yet".
   */
  protected Instances m_Instances;

  /** Constant for normal distribution: sqrt(2 * pi), used by normalDens. */
  protected static double NORM_CONST = Math.sqrt(2 * Math.PI);

  /**
   * Returns a string describing this classifier. The text consists of a short
   * description followed by the string form of the technical information
   * returned by {@link #getTechnicalInformation()}.
   * 
   * @return a description of the classifier suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String globalInfo() {
    // The first sentence ends with ". " so that the two sentences do not run
    // together as "classifier.Numeric" when the pieces are concatenated.
    return "Class for building and using a simple Naive Bayes classifier. "
      + "Numeric attributes are modelled by a normal distribution.\n\n"
      + "For more information, see\n\n" + getTechnicalInformation().toString();
  }

  /**
   * Builds the bibliographic reference for this classifier: the 1973 book by
   * Duda and Hart on which the implementation is based.
   * 
   * @return the technical information about this class
   */
  @Override
  public TechnicalInformation getTechnicalInformation() {
    final TechnicalInformation book = new TechnicalInformation(Type.BOOK);

    book.setValue(Field.AUTHOR, "Richard Duda and Peter Hart");
    book.setValue(Field.YEAR, "1973");
    book.setValue(Field.TITLE, "Pattern Classification and Scene Analysis");
    book.setValue(Field.PUBLISHER, "Wiley");
    book.setValue(Field.ADDRESS, "New York");

    return book;
  }

  /**
   * Returns default capabilities of the classifier: everything inherited from
   * the superclass is switched off first, then exactly the capabilities listed
   * below are switched back on.
   * 
   * @return the capabilities of this classifier
   */
  @Override
  public Capabilities getCapabilities() {
    final Capabilities caps = super.getCapabilities();
    caps.disableAll();

    final Capability[] supported = {
      // attributes
      Capability.NOMINAL_ATTRIBUTES,
      Capability.NUMERIC_ATTRIBUTES,
      Capability.DATE_ATTRIBUTES,
      Capability.MISSING_VALUES,
      // class
      Capability.NOMINAL_CLASS,
      Capability.MISSING_CLASS_VALUES
    };
    for (Capability capability : supported) {
      caps.enable(capability);
    }

    return caps;
  }

  /**
   * Generates the classifier. Training works on a copy of the data from which
   * instances with a missing class have been removed, and proceeds in passes:
   * value counts and sums, per-class means, per-class standard deviations, and
   * finally normalisation of the nominal counts and class priors (both with an
   * added count of one per value/class).
   * 
   * @param instances set of instances serving as training data
   * @exception Exception if the capabilities test fails, if a numeric
   *              attribute has fewer than two non-missing values for some
   *              class, or if its summed squared deviation for some class is
   *              not positive
   */
  @Override
  public void buildClassifier(Instances instances) throws Exception {

    // can classifier handle the data?
    getCapabilities().testWithFail(instances);

    // work on a private copy without class-less instances
    instances = new Instances(instances);
    instances.deleteWithMissingClass();

    m_Instances = new Instances(instances, 0);

    final int numClasses = instances.numClasses();
    final int numPredictors = instances.numAttributes() - 1;

    // Reserve space
    m_Counts = new double[numClasses][numPredictors][0];
    m_Means = new double[numClasses][numPredictors];
    m_Devs = new double[numClasses][numPredictors];
    m_Priors = new double[numClasses];

    // Walk the attribute enumeration once, remembering each attribute by its
    // position so that the later passes can use plain index loops.
    final Attribute[] predictors = new Attribute[numPredictors];
    final Enumeration<Attribute> attEnum = instances.enumerateAttributes();
    for (int a = 0; attEnum.hasMoreElements(); a++) {
      final Attribute attribute = attEnum.nextElement();
      predictors[a] = attribute;
      // one slot per value for nominal attributes, a single counter otherwise
      final int slots = attribute.isNominal() ? attribute.numValues() : 1;
      for (int c = 0; c < numClasses; c++) {
        m_Counts[c][a] = new double[slots];
      }
    }

    // Compute counts and sums
    for (Enumeration<Instance> insts = instances.enumerateInstances(); insts
      .hasMoreElements();) {
      final Instance instance = insts.nextElement();
      if (instance.classIsMissing()) {
        continue;
      }
      final int cls = (int) instance.classValue();
      for (int a = 0; a < numPredictors; a++) {
        final Attribute attribute = predictors[a];
        if (instance.isMissing(attribute)) {
          continue;
        }
        if (attribute.isNominal()) {
          m_Counts[cls][a][(int) instance.value(attribute)]++;
        } else {
          // m_Means holds the running sum until the "Compute means" pass
          m_Means[cls][a] += instance.value(attribute);
          m_Counts[cls][a][0]++;
        }
      }
      m_Priors[cls]++;
    }

    // Compute means
    for (int a = 0; a < numPredictors; a++) {
      final Attribute attribute = predictors[a];
      if (!attribute.isNumeric()) {
        continue;
      }
      for (int c = 0; c < numClasses; c++) {
        final double observed = m_Counts[c][a][0];
        if (observed < 2) {
          throw new Exception("attribute " + attribute.name()
            + ": less than two values for class "
            + instances.classAttribute().value(c));
        }
        m_Means[c][a] /= observed;
      }
    }

    // Compute standard deviations: first accumulate squared deviations ...
    for (Enumeration<Instance> insts = instances.enumerateInstances(); insts
      .hasMoreElements();) {
      final Instance instance = insts.nextElement();
      if (instance.classIsMissing()) {
        continue;
      }
      final int cls = (int) instance.classValue();
      for (int a = 0; a < numPredictors; a++) {
        final Attribute attribute = predictors[a];
        if (instance.isMissing(attribute) || !attribute.isNumeric()) {
          continue;
        }
        final double deviation = m_Means[cls][a] - instance.value(attribute);
        m_Devs[cls][a] += deviation * deviation;
      }
    }

    // ... then turn each sum into a standard deviation (divisor is count - 1)
    for (int a = 0; a < numPredictors; a++) {
      final Attribute attribute = predictors[a];
      if (!attribute.isNumeric()) {
        continue;
      }
      for (int c = 0; c < numClasses; c++) {
        if (m_Devs[c][a] <= 0) {
          throw new Exception("attribute " + attribute.name()
            + ": standard deviation is 0 for class "
            + instances.classAttribute().value(c));
        }
        m_Devs[c][a] /= m_Counts[c][a][0] - 1;
        m_Devs[c][a] = Math.sqrt(m_Devs[c][a]);
      }
    }

    // Normalize counts
    for (int a = 0; a < numPredictors; a++) {
      final Attribute attribute = predictors[a];
      if (!attribute.isNominal()) {
        continue;
      }
      for (int c = 0; c < numClasses; c++) {
        final double[] valueCounts = m_Counts[c][a];
        final double total = Utils.sum(valueCounts);
        for (int v = 0; v < attribute.numValues(); v++) {
          valueCounts[v] = (valueCounts[v] + 1)
            / (total + attribute.numValues());
        }
      }
    }

    // Normalize priors
    final double priorTotal = Utils.sum(m_Priors);
    for (int c = 0; c < numClasses; c++) {
      m_Priors[c] = (m_Priors[c] + 1) / (priorTotal + instances.numClasses());
    }
  }

  /**
   * Calculates the class membership probabilities for the given test instance.
   * <p/>
   * For every class the prior is combined with the likelihood of each
   * non-missing attribute value (the stored estimate for nominal attributes,
   * the normal density for all others). The combination is carried out as a
   * sum of logarithms and rescaled by the largest class score before
   * exponentiating, so that a long product of small likelihoods does not
   * underflow to zero for every class and make normalisation fail.
   * 
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @exception Exception if distribution can't be computed
   */
  @Override
  public double[] distributionForInstance(Instance instance) throws Exception {

    final int numClasses = instance.numClasses();
    final double[] logScores = new double[numClasses];

    for (int j = 0; j < numClasses; j++) {
      double logScore = Math.log(m_Priors[j]);
      Enumeration<Attribute> enumAtts = instance.enumerateAttributes();
      int attIndex = 0;
      while (enumAtts.hasMoreElements()) {
        Attribute attribute = enumAtts.nextElement();
        if (!instance.isMissing(attribute)) {
          double likelihood;
          if (attribute.isNominal()) {
            likelihood = m_Counts[j][attIndex][(int) instance.value(attribute)];
          } else {
            likelihood = normalDens(instance.value(attribute),
              m_Means[j][attIndex], m_Devs[j][attIndex]);
          }
          // log(0) is -Infinity, which correctly rules the class out
          logScore += Math.log(likelihood);
        }
        attIndex++;
      }
      logScores[j] = logScore;
    }

    // Largest score; used to shift all scores so the best class maps to 1
    double maxLogScore = Double.NEGATIVE_INFINITY;
    for (int j = 0; j < numClasses; j++) {
      if (logScores[j] > maxLogScore) {
        maxLogScore = logScores[j];
      }
    }

    double[] probs = new double[numClasses];
    if (maxLogScore > Double.NEGATIVE_INFINITY) {
      for (int j = 0; j < numClasses; j++) {
        probs[j] = Math.exp(logScores[j] - maxLogScore);
      }
    }
    // else: every class has likelihood zero; probs stays all-zero and the
    // normalisation below reports the problem exactly as before.

    // Normalize probabilities
    Utils.normalize(probs);

    return probs;
  }

  /**
   * Returns a description of the classifier: for each class its prior, then
   * for each attribute either the estimated probability of every nominal value
   * or the mean and standard deviation. A fixed message is returned when no
   * model has been built, and another when printing fails with an exception.
   * 
   * @return a description of the classifier as a string.
   */
  @Override
  public String toString() {

    if (m_Instances == null) {
      return "Naive Bayes (simple): No model built yet.";
    }
    try {
      final StringBuilder out = new StringBuilder("Naive Bayes (simple)");

      for (int cls = 0; cls < m_Instances.numClasses(); cls++) {
        out.append("\n\nClass ").append(m_Instances.classAttribute().value(cls))
          .append(": P(C) = ")
          .append(Utils.doubleToString(m_Priors[cls], 10, 8)).append("\n\n");

        int attIndex = 0;
        for (Enumeration<Attribute> atts = m_Instances.enumerateAttributes(); atts
          .hasMoreElements(); attIndex++) {
          final Attribute attribute = atts.nextElement();
          out.append("Attribute ").append(attribute.name()).append("\n");

          if (attribute.isNominal()) {
            // header row of value labels, then a row of their probabilities
            for (int v = 0; v < attribute.numValues(); v++) {
              out.append(attribute.value(v)).append("\t");
            }
            out.append("\n");
            for (int v = 0; v < attribute.numValues(); v++) {
              out.append(Utils.doubleToString(m_Counts[cls][attIndex][v], 10, 8))
                .append("\t");
            }
          } else {
            out.append("Mean: ")
              .append(Utils.doubleToString(m_Means[cls][attIndex], 10, 8))
              .append("\t");
            out.append("Standard Deviation: ")
              .append(Utils.doubleToString(m_Devs[cls][attIndex], 10, 8));
          }
          out.append("\n\n");
        }
      }

      return out.toString();
    } catch (Exception e) {
      return "Can't print Naive Bayes classifier!";
    }
  }

  /**
   * Density function of normal distribution, evaluated as
   * 1 / (sqrt(2 * pi) * stdDev) * exp(-(x - mean)^2 / (2 * stdDev^2)).
   * 
   * @param x the value to get the density for
   * @param mean the mean
   * @param stdDev the standard deviation
   * @return the density
   */
  protected double normalDens(double x, double mean, double stdDev) {

    final double diff = x - mean;
    final double scale = 1 / (NORM_CONST * stdDev);
    final double exponent = -(diff * diff / (2 * stdDev * stdDev));

    return scale * Math.exp(exponent);
  }

  /**
   * Returns the revision string, obtained by passing the version-control
   * keyword embedded in this file to RevisionUtils.extract.
   * 
   * @return the revision
   */
  @Override
  public String getRevision() {
    final String revisionKeyword = "$Revision$";
    return RevisionUtils.extract(revisionKeyword);
  }

  /**
   * Main method for testing this class: creates a fresh classifier and hands
   * it, together with the command-line options, to runClassifier.
   * 
   * @param argv the options
   */
  public static void main(String[] argv) {
    final NaiveBayesSimple classifier = new NaiveBayesSimple();
    runClassifier(classifier, argv);
  }
}
